"""
Given a pretrained contexteval model, extract the Elmo contextualizer
and compare each layer with a pre-trained AllenNLP Elmo contextualizer.
"""
import argparse
import logging
import os
import sys

from allennlp.models import load_archive
from allennlp.modules.elmo import Elmo
import torch

sys.path.insert(0, os.path.dirname(os.path.abspath(os.path.join(__file__, os.pardir))))
import contexteval  # noqa:F401

# Module-level logger; its format and level are configured in the ``__main__`` block.
logger = logging.getLogger(__name__)


# Default locations of the pre-trained ELMo options (JSON) and weights (HDF5)
# used as the reference model when no --options-file / --weight-file is given.
DEFAULT_OPTIONS_PATH = ("https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/"
                        "2x4096_512_2048cnn_2xhighway/"
                        "elmo_2x4096_512_2048cnn_2xhighway_options.json")
DEFAULT_WEIGHT_PATH = ("https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/"
                       "2x4096_512_2048cnn_2xhighway/"
                       "elmo_2x4096_512_2048cnn_2xhighway_weights.hdf5")


def _average_shift(finetuned_state_dict, original_state_dict, key_filter):
    """Return the mean absolute element-wise difference of the selected parameters.

    Every entry of ``finetuned_state_dict`` whose key satisfies ``key_filter`` is
    compared with the entry stored under the same key in ``original_state_dict``.
    The absolute differences are summed over all selected entries and divided by
    the total number of selected elements.

    Args:
        finetuned_state_dict: mapping from parameter name to tensor (fine-tuned model).
        original_state_dict: mapping from parameter name to tensor (reference model).
        key_filter: predicate on a parameter name selecting which entries to compare.

    Returns:
        The average shift as a float, or ``None`` when no parameter matches
        ``key_filter`` (rather than dividing by zero).

    Raises:
        KeyError: if a selected key is missing from ``original_state_dict``.
        ValueError: if a selected pair of tensors differs in shape; without this
            check, broadcasting could silently yield a meaningless distance.
    """
    total_shift = 0.0
    num_params = 0
    for key, parameter in finetuned_state_dict.items():
        if not key_filter(key):
            continue
        original_parameter = original_state_dict[key]
        if parameter.shape != original_parameter.shape:
            raise ValueError("Shape mismatch for {}: fine-tuned {} vs. original {}".format(
                key, tuple(parameter.shape), tuple(original_parameter.shape)))
        num_params += parameter.numel()
        total_shift += torch.abs(parameter - original_parameter).sum().item()
    if num_params == 0:
        return None
    return total_shift / num_params


def _group_scalar_mixes(state_dict):
    """Collect scalar-mix weights, grouped by the scalar-mix module that owns them.

    A key such as ``scalar_mix_0.scalar_parameters.1`` is grouped under the prefix
    ``scalar_mix_0.`` (everything before ``scalar_parameters``). Groups and the
    weights inside each group follow the iteration order of ``state_dict``.

    Returns:
        dict mapping each prefix to the list of its scalar-parameter tensors
        (empty if ``state_dict`` has no scalar-mix weights).
    """
    scalar_mixes = {}
    for key, parameter in state_dict.items():
        if "scalar_parameters" in key:
            prefix = key.partition("scalar_parameters")[0]
            scalar_mixes.setdefault(prefix, []).append(parameter)
    return scalar_mixes


def _normalize_scalars(scalar_parameters):
    """Concatenate 1-D scalar-mix weights, softmax them, and return Python floats."""
    return torch.nn.functional.softmax(torch.cat(scalar_parameters), dim=0).tolist()


def main(parsed_args=None):
    """Log how far a fine-tuned ELMo contextualizer has moved from pre-trained ELMo.

    Loads the AllenNLP archive at ``parsed_args.archive_path`` and takes the state
    dict of its ``model._contextualizer._elmo`` module. It compares that state dict
    against an ``Elmo`` built from ``parsed_args.options_file`` and
    ``parsed_args.weight_file``. It logs the mean absolute parameter shift for the
    token embedder and for each of the two LSTM layers. It then logs the
    softmax-normalized scalar-mix weights and gamma of the fine-tuned model.

    Args:
        parsed_args: argparse namespace with ``archive_path``, ``options_file`` and
            ``weight_file``. Defaults to the module-level ``args`` assigned when
            this file is run as a script.
    """
    if parsed_args is None:
        # Fall back to the namespace parsed in the ``__main__`` block of this file.
        parsed_args = args

    # Load generated model file
    archive = load_archive(parsed_args.archive_path)
    model = archive.model
    finetuned_elmo_state_dict = model._contextualizer._elmo.state_dict()

    # Load the pre-trained ELMo options and weights to compare against.
    elmo = Elmo(parsed_args.options_file, parsed_args.weight_file, 1)
    original_elmo_state_dict = elmo.state_dict()

    # (name used in the log message, predicate selecting that group's parameters)
    parameter_groups = [
        ("token embedder", lambda key: "token_embedder" in key),
        ("LSTM Layer 0", lambda key: "backward_layer_0" in key or "forward_layer_0" in key),
        ("LSTM Layer 1", lambda key: "backward_layer_1" in key or "forward_layer_1" in key),
    ]
    for description, key_filter in parameter_groups:
        shift = _average_shift(finetuned_elmo_state_dict, original_elmo_state_dict, key_filter)
        if shift is None:
            logger.warning("No parameters found for {}; cannot compute its shift.".format(description))
        else:
            logger.info("Average Shift (L1 distance) in {}: {}".format(description, shift))

    # Report each scalar mix separately; softmaxing weights of different mixes
    # together would produce meaningless numbers.
    scalar_mixes = _group_scalar_mixes(finetuned_elmo_state_dict)
    if not scalar_mixes:
        logger.warning("No scalar mix parameters found in the fine-tuned model.")
    for prefix, scalar_parameters in scalar_mixes.items():
        # Only name the mix when there are several, so single-mix output is unchanged.
        label = "" if len(scalar_mixes) == 1 else " ({})".format(prefix.rstrip("."))
        logger.info("Normalized Scalar Mix of fine-tuned model{}: {}".format(
            label, _normalize_scalars(scalar_parameters)))
        logger.info("Gamma of fine-tuned model{}: {}".format(
            label, finetuned_elmo_state_dict[prefix + "gamma"].item()))


if __name__ == "__main__":
    logging.basicConfig(format="%(asctime)s - %(levelname)s "
                        "- %(name)s - %(message)s",
                        level=logging.INFO)
    parser = argparse.ArgumentParser(
        description=("Given a path to a model with a fine-tuned ElmoContextualizer, "
                     "report the parameter shift for each layer as compared to a "
                     "pre-trained ELMo model."),
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--archive-path", type=str,
                        help=("Path to a model.tar.gz file generated by AllenNLP."))
    parser.add_argument('--options-file', type=str, default=DEFAULT_OPTIONS_PATH,
                        help='The path to the ELMo options file.')
    parser.add_argument('--weight-file', type=str, default=DEFAULT_WEIGHT_PATH,
                        help='The path to the ELMo weight file.')
    args = parser.parse_args()
    main()
